package com.jstarcraft.ai.jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.linear.Vec;

/**
 * This class provides a method of performing {@link KMeans} clustering when the
 * value of {@code K} is not known. It works by incrementing the value of
 * {@code k} up to some specified maximum, and running a full KMeans for each
 * value. <br>
 * <br>
 * Note, by default this implementation uses a heuristic for the max value of
 * {@code K} that is capped at 100 when using the
 * {@link #cluster(com.jstarcraft.ai.jsat.DataSet) } type methods. <br>
 * <br>
 * When the value of {@code K} is specified, the implementation will simply call
 * the regular KMeans object it was constructed with.
 * 
 * See: Pham, D. T., Dimov, S. S.,&amp;Nguyen, C. D. (2005). <i>Selection of K
 * in K-means clustering</i>. Proceedings of the Institution of Mechanical
 * Engineers, Part C: Journal of Mechanical Engineering Science, 219(1),
 * 103–119. doi:10.1243/095440605X8298
 * 
 * @author Edward Raff
 */
public class KMeansPDN extends KMeans {
    private static final long serialVersionUID = -2358377567814606959L;
    private KMeans kmeans;
    private double[] fKs;

    /**
     * Creates a new clusterer.
     */
    public KMeansPDN() {
        this(new HamerlyKMeans());
    }

    /**
     * Creates a new clustered that uses the specified object to perform clustering
     * for all {@code k}.
     * 
     * @param kmeans the k-means object to use for clustering
     */
    public KMeansPDN(KMeans kmeans) {
        super(kmeans.dm, kmeans.seedSelection, kmeans.rand);
        this.kmeans = kmeans;
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    public KMeansPDN(KMeansPDN toCopy) {
        super(toCopy);
        this.kmeans = toCopy.kmeans.clone();
        if (toCopy.fKs != null)
            this.fKs = Arrays.copyOf(toCopy.fKs, toCopy.fKs.length);
    }

    /**
     * Returns the array of {@code f(K)} values generated for the last data set. The
     * value at index {@code i} is the score for cluster {@code i+1}. Smaller values
     * indicate better clusterings.
     * 
     * @return the array of {@code f(K)} values, or {@code null} if no data set has
     *         been clustered
     */
    public double[] getfKs() {
        return fKs;
    }

    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return cluster(dataSet, 1, (int) Math.min(Math.max(Math.sqrt(dataSet.size()), 10), 100), parallel, designations);
    }

    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        if (highK == lowK)
            return cluster(dataSet, lowK, parallel, designations);
        else if (highK < lowK)
            throw new IllegalArgumentException("low value of k (" + lowK + ") must be higher than the high value of k(" + highK + ")");
        final int N = dataSet.size();
        final int D = dataSet.getNumNumericalVars();
        fKs = new double[highK - 1];// we HAVE to start from k=2
        fKs[0] = 1.0;// see eq(2)

        int[] bestCluster = new int[N];
        double minFk = lowK == 1 ? 1.0 : Double.POSITIVE_INFINITY;// If our low k is > 1, force the check later to kick in at the first candidate
                                                                  // k by making fK appear Inf

        if (designations == null || designations.length < N)
            designations = new int[N];

        double alphaKprev = 0, S_k_prev = 0;

        // re used every iteration
        List<Vec> curMeans = new ArrayList<>(highK);
        means = new ArrayList<>();// the best set of means
        // pre-compute cache instead of re-computing every time
        List<Double> accelCache = dm.getAccelerationCache(dataSet.getDataVectors(), parallel);

        for (int k = 2; k < highK; k++) {
            curMeans.clear();
            // kmeans objective function result is the same as S_k
            double S_k = cluster(dataSet, accelCache, k, curMeans, designations, true, parallel, true, null);// TODO could add a flag to make approximate S_k an option. Though it dosn't
                                                                                                             // seem to work great on toy problems, might be fine on more realistic data

            double alpha_k;
            if (k == 2)
                alpha_k = 1 - 3.0 / (4 * D); // eq(3a)
            else
                alpha_k = alphaKprev + (1 - alphaKprev) / 6;// eq(3b)

            double fK;// eq(2)
            if (S_k_prev == 0)
                fKs[k - 1] = fK = 1;
            else
                fKs[k - 1] = fK = S_k / (alpha_k * S_k_prev);

            alphaKprev = alpha_k;
            S_k_prev = S_k;

            if (k >= lowK && minFk > fK) {
                System.arraycopy(designations, 0, bestCluster, 0, N);
                minFk = fK;
                means.clear();
                for (Vec mean : curMeans)
                    means.add(mean.clone());
            }
        }

        // contract is we return designations with the data in it if we can, so copy the
        // values back
        System.arraycopy(bestCluster, 0, designations, 0, N);
        return designations;
    }

    @Override
    protected double cluster(DataSet dataSet, List<Double> accelCache, int k, List<Vec> means, int[] assignment, boolean exactTotal, boolean threadpool, boolean returnError, Vec dataPointWeights) {
        return kmeans.cluster(dataSet, accelCache, k, means, assignment, exactTotal, threadpool, returnError, null);
    }

    @Override
    public KMeansPDN clone() {
        return new KMeansPDN(this);
    }

}
